feat(lora): save/restore LoRA config in checkpoint metadata#4269
feat(lora): save/restore LoRA config in checkpoint metadata#4269RexBearIU wants to merge 1 commit into
Conversation
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
shralex
left a comment
There was a problem hiding this comment.
Thanks Jackie! A significant thing missing in this PR is using the metadata file on checkpoint restore path.
187905b to
cd17578
Compare
|
Hi @shralex, thank you for the feedback! I have fully addressed your comments with the following changes:
Please let me know if you would like any other enhancements! |
cd17578 to
1b15640
Compare
1b15640 to
ae44adc
Compare
69c78a7 to
a701719
Compare
a701719 to
07c5e19
Compare
added the logic to re-use the metadata for checkpoint restore. |
5940e65 to
9bc253e
Compare
9bc253e to
0f6248b
Compare
shralex
left a comment
There was a problem hiding this comment.
This version reverts Xibin's previous version where sync_lora_metadata was in lora_utils. We should move it back there and use it not just on checkpoint conversion but also before model creation.
| replicator_error_handler(config) | ||
| return checkpoint_manager.save(step, args=Composite(state=checkpoint_args), force=force) | ||
| return checkpoint_manager.save( | ||
| step, args=Composite(state=checkpoint_args), force=force, custom_metadata=custom_metadata |
There was a problem hiding this comment.
EmergencyCheckpointManager and EmergencyReplicatorCheckpointManager do not accept a custom metadata argument. Lets leave this argument out here, and open a bug to add this support
There was a problem hiding this comment.
Done! Omitted passing the custom_metadata argument when calling .save() on EmergencyCheckpointManager or EmergencyReplicatorCheckpointManager.
There was a problem hiding this comment.
I've created a bug b/529671188 for Orbax team to add support on EmergencyCheckpointManager or EmergencyReplicatorCheckpointManager
d55b90d to
ffe10de
Compare
Done. Moved |
2649217 to
e58e177
Compare
c3b3d77 to
e8f3545
Compare
Co-authored-by: Xibin Liu <xibin@google.com>
e8f3545 to
ba410a3
Compare
Description
This PR implements native serialization of LoRA configuration parameters (
lora_rank,lora_alpha) in standard Orbax_CHECKPOINT_METADATAfiles, and automatically restores them during checkpoint-to-Hugging Face conversion.Why is this change being made?
Previously, users had to manually supply matching
lora.lora_rankandlora.lora_alphaparameters when converting MaxText checkpoints to Hugging Face format. Storing them in Orbax metadata makes the conversion seamless and error-free (resolves @igorts-git's request in #3970).Key Implementation Details
save_checkpoint(checkpointing.py), we save the activeconfig.lorablock under the"lora"key in Orbax'scustom_metadatawhen a LoRA rank is specified.main(to_huggingface.py),sync_lora_metadatareads the custom metadata fromlora_restore_pathviaocp.StandardCheckpointerand overrides active config parameters during conversion.hf_checkpoint_conversion_test.pyto move dynamically loaded inline imports to global top-level imports and completely removedjsonimport since JSON string is written directly.BUGS: #3970
Tests
We have verified the implementation with complete suite-level and individual unit-tests:
SyncLoRAMetadataTestintests/unit/hf_checkpoint_conversion_test.pyto verify the auto-resolving mechanism during Hugging Face conversion.python tests/unit/hf_checkpoint_conversion_test.pyAll tests pass successfully.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.